from typing import List

import pandas as pd
import pingouin as pg

from config import Config
from dataProcessing import DataProcessing


class StatisticalCalculations:
    """Statistical analyses of the trial answers and questionnaire data.

    Every public method loads its input via ``DataProcessing`` and writes its
    result tables as CSV files into ``Config.StatisticsOutputPath``.
    """

    # region trials
    @staticmethod
    def _rm_anova_dependent_variables(part: int) -> List[str]:
        """Return the dependent variables analysed by ``rm_anovas_trials`` for ``part``.

        Part 0 uses only the shared measures. Parts 1-3 additionally analyse
        ``size``, ``tilt`` or ``distance`` respectively. Each variable appears
        exactly once, because every key maps to its own output file.
        """
        part_specific = [] if part == 0 else [["size", "tilt", "distance"][part - 1]]
        return part_specific + ["viewingAngle", "angularSize", "duration"]

    @staticmethod
    def rm_anovas_trials():
        """Run repeated-measures ANOVAs on the trial answers and save each table as CSV.

        For every ``part`` in 0..3, rows whose ``part`` column equals that number
        are dropped. Then one ``pingouin.rm_anova`` is run per dependent variable
        from ``_rm_anova_dependent_variables(part)``, using:

        - within factors ``area`` and ``content``
        - subject ``participant``
        - effect size ``np2``

        Each result is written to
        ``Stats 01 RM Anova P{part} {key}.csv``.
        """
        input_df = DataProcessing.get_trials_answers()

        for part in range(0, 4):
            # Rows of this part are excluded; this mirrors the "not Part N"
            # filtering used by aggregated_values_trials.
            part_df = input_df[input_df["part"] != part]

            for key in StatisticalCalculations._rm_anova_dependent_variables(part):
                output_df = pg.rm_anova(
                    data=part_df,
                    dv=key,
                    within=["area", "content"],
                    subject="participant",
                    effsize="np2"
                )
                output_df.to_csv(
                    path_or_buf=f"{Config.StatisticsOutputPath}/Stats 01 RM Anova P{part} {key}.csv"
                )

    @staticmethod
    def aggregated_values_trials(aggregations: List[str]):
        """Write grouped descriptive aggregations of the trial answers to CSV.

        :param aggregations: pandas aggregation names applied to every measure
            column (e.g. ``["describe"]``).

        Two families of files are produced, each grouped by ``area``, by
        ``content``, and by both:

        - ``Stats 02 Aggregate {group_keys}.csv``: ``angularSize``,
          ``viewingAngle`` and ``duration`` over all rows.
        - ``Stats 03 Aggregate {group_keys} not Part {n}.csv``: the part-specific
          measure (``size``/``tilt``/``distance`` for n = 1/2/3) plus the shared
          measures, computed with the rows of part ``n`` excluded.
        """
        input_df = DataProcessing.get_trials_answers()
        grouping_options = [["area"], ["content"], ["area", "content"]]

        for group_keys in grouping_options:
            measures = ["angularSize", "viewingAngle", "duration"]
            output_df = input_df[group_keys + measures].groupby(
                by=group_keys
            ).agg({measure: aggregations for measure in measures})
            # group_keys is formatted as its list repr on purpose; the existing
            # output file names depend on it.
            output_df.to_csv(
                path_or_buf=f"{Config.StatisticsOutputPath}/Stats 02 Aggregate {group_keys}.csv"
            )

        for group_keys in grouping_options:
            for i, key in enumerate(["size", "tilt", "distance"]):
                measures = [key, "angularSize", "viewingAngle", "duration"]
                output_df = input_df.loc[input_df["part"] != i + 1, group_keys + measures].groupby(
                    by=group_keys
                ).agg({measure: aggregations for measure in measures})
                output_df.to_csv(
                    path_or_buf=f"{Config.StatisticsOutputPath}/Stats 03 Aggregate {group_keys} not Part {i + 1}.csv"
                )
    # endregion

    # region questionnaire
    @staticmethod
    def _prior_post_long_format(input_df: pd.DataFrame) -> pd.DataFrame:
        """Reshape questionnaire rows into one "prior" and one "post" row per participant.

        :param input_df: frame with a ``participant`` column and the columns
            ``HealthConditionPrior[SQ00q]`` / ``HealthConditionAfter[SQ00q]``
            for q = 1..6.
        :return: frame with columns ``participant``, ``studyState`` and
            ``SQ001``..``SQ006``. For every input row (in input order) it holds
            a ``"prior"`` row followed by a ``"post"`` row.
        """
        long_dict = {
            "participant": [],
            "studyState": [],
        }
        long_dict.update({f"SQ00{q}": [] for q in range(1, 7)})

        for _, row in input_df.iterrows():
            long_dict["participant"].extend([row["participant"]] * 2)
            long_dict["studyState"].extend(["prior", "post"])
            for q in range(1, 7):
                long_dict[f"SQ00{q}"].extend([row[f"HealthConditionPrior[SQ00{q}]"], row[f"HealthConditionAfter[SQ00{q}]"]])
        return pd.DataFrame.from_dict(long_dict)

    @staticmethod
    def ttests_questionnaire():
        """Run a paired t-test (prior vs. post) for each health-condition item SQ001..SQ006.

        Each ``pingouin.ttest`` result is written to ``Stats 04 ttest {key}.csv``.
        """
        input_df = StatisticalCalculations._prior_post_long_format(
            DataProcessing.get_questionnaire_data()
        )
        # The reshape keeps participants in the same order within the "prior"
        # and the "post" rows, so x[i] and y[i] belong to the same participant,
        # which the paired test requires.
        prior_df = input_df[input_df["studyState"] == "prior"]
        post_df = input_df[input_df["studyState"] == "post"]

        for key in [f"SQ00{q}" for q in range(1, 7)]:
            output_df = pg.ttest(
                x=prior_df[key],
                y=post_df[key],
                paired=True,
            )
            output_df.to_csv(
                path_or_buf=f"{Config.StatisticsOutputPath}/Stats 04 ttest {key}.csv"
            )

    @staticmethod
    def aggregated_values_questionnaire(aggregations: List[str]):
        """Write descriptive aggregations of the questionnaire data to CSV.

        :param aggregations: pandas aggregation names applied to every column
            (e.g. ``["describe"]``).

        Produces:

        - ``Stats 05 Aggregate Questionnaire.csv``: the
          ``HealthConditionDiff[SQ001..SQ006]`` and
          ``PreferenceContent[SQ001..SQ002]`` columns over all participants.
        - ``Stats 06 Aggregate studyState.csv``: items SQ001..SQ006 grouped by
          ``studyState`` (prior / post).
        """
        input_df = DataProcessing.get_questionnaire_data()
        questionnaire_keys = [f"HealthConditionDiff[SQ00{q}]" for q in range(1, 7)] + [f"PreferenceContent[SQ00{q}]" for q in range(1, 3)]

        output_df = input_df[questionnaire_keys].agg({k: aggregations for k in questionnaire_keys})
        output_df.to_csv(
            path_or_buf=f"{Config.StatisticsOutputPath}/Stats 05 Aggregate Questionnaire.csv"
        )

        long_df = StatisticalCalculations._prior_post_long_format(input_df)
        output_df = long_df.groupby(
            by="studyState"
        ).agg({f"SQ00{q}": aggregations for q in range(1, 7)})
        output_df.to_csv(
            path_or_buf=f"{Config.StatisticsOutputPath}/Stats 06 Aggregate studyState.csv"
        )
    # endregion


if __name__ == '__main__':
    # Aggregations applied by both "aggregated_values_*" calls below. Other
    # pandas aggregation names (for example "median" or "sem") can be added here.
    summary_aggregations = ["describe"]

    # Trial-based analyses.
    StatisticalCalculations.rm_anovas_trials()
    StatisticalCalculations.aggregated_values_trials(aggregations=summary_aggregations)

    # Questionnaire-based analyses.
    StatisticalCalculations.ttests_questionnaire()
    StatisticalCalculations.aggregated_values_questionnaire(aggregations=summary_aggregations)
